Conversation
There was a problem hiding this comment.
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: c85ce08e94
ℹ️ About Codex in GitHub
Your team has set up Codex to review pull requests in this repo. Reviews are triggered when you
- Open a pull request for review
- Mark a draft as ready
- Comment "@codex review".
If Codex has suggestions, it will comment; otherwise it will react with 👍.
Codex can also answer questions or update the PR. Try commenting "@codex address that feedback".
| bucket_state = states[indices[0]].setdefault(bucket_key, {}) | ||
| for i in indices[1:]: | ||
| states[i][bucket_key] = bucket_state |
There was a problem hiding this comment.
Don't reuse bucket state when the active set changes
When same-shaped parameters share bucket_state here, the tensors inside that state are sized by the current bucket's leading dimension. Since _step_inner builds param from split_p_and_g_in_group(..., skip_none=True), any parameter with grad is None drops out of the bucket for that step; reusing the old shared state then either shifts later parameters onto the wrong slot or leaves exp_avg/GG with a leading dimension from a previous active set, causing shape errors in the SOAP/PSGD updates. This affects normal models with intermittently unused parameters, so the bucket key/state needs to be stable for the exact active members or handle resizing/masking.
Useful? React with 👍 / 👎.
There was a problem hiding this comment.
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: 7b4a4652d9
ℹ️ About Codex in GitHub
Your team has set up Codex to review pull requests in this repo. Reviews are triggered when you
- Open a pull request for review
- Mark a draft as ready
- Comment "@codex review".
If Codex has suggestions, it will comment; otherwise it will react with 👍.
Codex can also answer questions or update the PR. Try commenting "@codex address that feedback".
| for i in indices[1:]: | ||
| states[i][bucket_key] = bucket_state | ||
|
|
||
| result = self.fn([bucket_state], group, [slab_u], [slab_g], [slab_p], *args, **kwargs) |
There was a problem hiding this comment.
Preserve HVP temporaries when bucketing PSGD
When PSGD/PSGDPRO are used with hessian_approx=True and same-shaped parameters are bucketed, the HVP closure stores vector/hessian_vector temporaries under each original parameter view, but this call passes a freshly stacked slab_p into the inner PSGD chain. _update_psgd_precond() later looks up get_temporary(group, param) using that slab tensor identity, so the lookup misses and it silently falls back to dampen_grad(...) instead of using the computed HVPs, fitting the wrong preconditioner for these runs.
Useful? React with 👍 / 👎.
No description provided.